/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.dubbo.rpc.protocol.grpc;

import org.apache.dubbo.common.URL;
import org.apache.dubbo.config.ReferenceConfigBase;
import org.apache.dubbo.rpc.Invoker;
import org.apache.dubbo.rpc.ProtocolServer;
import org.apache.dubbo.rpc.RpcException;
import org.apache.dubbo.rpc.model.ApplicationModel;
import org.apache.dubbo.rpc.model.ProviderModel;
import org.apache.dubbo.rpc.model.ServiceRepository;
import org.apache.dubbo.rpc.protocol.AbstractProxyProtocol;

import io.grpc.BindableService;
import io.grpc.CallOptions;
import io.grpc.Channel;
import io.grpc.ManagedChannel;
import io.grpc.Server;
import io.grpc.netty.NettyServerBuilder;

import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;

/**
 * Dubbo protocol that exports and refers services over gRPC.
 *
 * <p>Providers: one Netty-based gRPC server is created per listen address. Each exported service's generated
 * stub is registered in that server's fallback {@link DubboHandlerRegistry}.
 *
 * <p>Consumers: one {@link ReferenceCountManagedChannel} is shared per remote address. Every referring invoker
 * takes a reference on it.
 */
public class GrpcProtocol extends AbstractProxyProtocol {

    public static final int DEFAULT_PORT = 50051;

    /* <address, gRPC channels> */
    private final ConcurrentMap<String, ReferenceCountManagedChannel> channelMap = new ConcurrentHashMap<>();
    private final Object lock = new Object();

    /**
     * Exports a service on the gRPC server bound to {@code url}'s address, creating and starting that server
     * if needed.
     *
     * <p>The service instance stored in the service repository must be a stub generated by the
     * dubbo-protoc-compiler. It must expose {@code setProxiedImpl(type)} and be a {@link BindableService}.
     *
     * @param proxiedImpl the Dubbo-proxied implementation injected into the generated stub
     * @param type        the service interface
     * @param url         the export URL; its address selects the server and its service key names the service
     * @param <T>         the service interface type
     * @return a callback that unregisters the service from the server's registry
     * @throws IllegalStateException if the provider model is missing or the stub cannot accept the proxied impl
     */
    @Override
    protected <T> Runnable doExport(T proxiedImpl, Class<T> type, URL url) throws RpcException {
        String key = url.getAddress();
        ProtocolServer protocolServer = serverMap.computeIfAbsent(key, k -> {
            DubboHandlerRegistry registry = new DubboHandlerRegistry();

            NettyServerBuilder builder = NettyServerBuilder
                    .forPort(url.getPort())
                    .fallbackHandlerRegistry(registry);

            Server originalServer = GrpcOptionsUtils.buildServerBuilder(url, builder).build();
            GrpcRemotingServer remotingServer = new GrpcRemotingServer(originalServer, registry);
            return new ProxyProtocolServer(remotingServer);
        });

        GrpcRemotingServer grpcServer = (GrpcRemotingServer) protocolServer.getRemotingServer();

        ServiceRepository serviceRepository = ApplicationModel.getServiceRepository();
        ProviderModel providerModel = serviceRepository.lookupExportedService(url.getServiceKey());
        if (providerModel == null) {
            throw new IllegalStateException("Service " + url.getServiceKey() + " should have already been stored in service repository, " +
                    "but failed to find it.");
        }
        Object originalImpl = providerModel.getServiceInstance();

        Class<?> implClass = originalImpl.getClass();
        try {
            // The generated stub delegates its gRPC handlers to the Dubbo-proxied implementation.
            Method method = implClass.getMethod("setProxiedImpl", type);
            method.invoke(originalImpl, proxiedImpl);
        } catch (Exception e) {
            throw new IllegalStateException("Failed to set dubbo proxied service impl to stub, please make sure your stub " +
                    "was generated by the dubbo-protoc-compiler.", e);
        }
        grpcServer.getRegistry().addService((BindableService) originalImpl, url.getServiceKey());

        // start() is synchronized and idempotent, so concurrent exports on the same address cannot double-start it.
        if (!grpcServer.isStarted()) {
            grpcServer.start();
        }

        return () -> grpcServer.getRegistry().removeService(url.getServiceKey());
    }

    /**
     * Creates an invoker for a protobuf-generated service interface.
     *
     * <p>The invoker is backed by a stub obtained reflectively from the enclosing generated class's static
     * {@code getDubboStub(Channel, CallOptions, URL, ReferenceConfigBase)} method.
     *
     * @param type the service interface; must be nested inside the generated {@code *Grpc} class
     * @param url  the refer URL; its address selects the shared channel
     * @param <T>  the service interface type
     * @return the registered gRPC invoker
     * @throws IllegalArgumentException if {@code type} is not nested or the generated class lacks {@code getDubboStub}
     * @throws IllegalStateException    if the stub cannot be created reflectively
     */
    @Override
    protected <T> Invoker<T> protocolBindingRefer(final Class<T> type, final URL url) throws RpcException {
        Class<?> enclosingClass = type.getEnclosingClass();

        if (enclosingClass == null) {
            throw new IllegalArgumentException(type.getName() + " must be declared inside protobuf generated classes, " +
                    "should be something like ServiceNameGrpc.IServiceName.");
        }

        final Method dubboStubMethod;
        try {
            dubboStubMethod = enclosingClass.getDeclaredMethod("getDubboStub", Channel.class, CallOptions.class,
                    URL.class, ReferenceConfigBase.class);
        } catch (NoSuchMethodException e) {
            throw new IllegalArgumentException("Does not find getDubboStub in " + enclosingClass.getName() + ", please use the customized protoc-gen-dubbo-java to update the generated classes.", e);
        }

        // Channel: a reference is acquired here. Every failure below must release it, or the channel leaks.
        ReferenceCountManagedChannel channel = getSharedChannel(url);

        try {
            @SuppressWarnings("unchecked") final T stub = (T) dubboStubMethod.invoke(null,
                    channel,
                    GrpcOptionsUtils.buildCallOptions(url),
                    url,
                    ApplicationModel.getConsumerModel(url.getServiceKey()).getReferenceConfig()
            );
            final Invoker<T> target = proxyFactory.getInvoker(stub, type, url);
            GrpcInvoker<T> grpcInvoker = new GrpcInvoker<>(type, url, target, channel);
            invokers.add(grpcInvoker);
            return grpcInvoker;
        } catch (IllegalAccessException | InvocationTargetException e) {
            channel.shutdown();
            throw new IllegalStateException("Could not create stub through reflection.", e);
        } catch (RuntimeException e) {
            // e.g. a missing consumer model (NPE) or a failure while building the invoker
            channel.shutdown();
            throw e;
        }
    }

    /**
     * Not used: {@link #protocolBindingRefer(Class, URL)} builds gRPC invokers directly.
     *
     * @param type ignored
     * @param url  ignored
     * @param <T>  ignored
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    @Override
    protected <T> T doRefer(Class<T> type, URL url) throws RpcException {
        throw new UnsupportedOperationException("not used");
    }

    /**
     * Returns the channel shared by all references to {@code url}'s address.
     *
     * <p>A new channel is created when none exists or the cached one has been shut down.
     *
     * @param url the refer URL; only its address is used as the cache key
     * @return the shared channel, with a reference taken on behalf of the caller
     */
    private ReferenceCountManagedChannel getSharedChannel(URL url) {
        String key = url.getAddress();
        ReferenceCountManagedChannel channel = channelMap.get(key);

        // A shut-down channel (even one still draining, i.e. not yet terminated) rejects new calls,
        // so it must not be handed to new invokers.
        if (channel != null && !channel.isShutdown()) {
            channel.incrementAndGetCount();
            return channel;
        }

        synchronized (lock) {
            channel = channelMap.get(key);
            // double check: another thread may have replaced the channel while we waited for the lock
            if (channel != null && !channel.isShutdown()) {
                channel.incrementAndGetCount();
            } else {
                // NOTE(review): the new channel is not incremented here. This assumes the initial count of
                // ReferenceCountManagedChannel already accounts for this first reference — confirm.
                channel = new ReferenceCountManagedChannel(initChannel(url));
                channelMap.put(key, channel);
            }
        }

        return channel;
    }

    /**
     * Creates a new gRPC channel for {@code url} using the configured channel options.
     *
     * @param url the refer URL
     * @return a new, unshared managed channel
     */
    private ManagedChannel initChannel(URL url) {
        return GrpcOptionsUtils.buildManagedChannel(url);
    }

    @Override
    public int getDefaultPort() {
        return DEFAULT_PORT;
    }

    /**
     * Closes every gRPC server and releases one reference on every cached channel.
     * Then clears both caches and delegates to the parent's destroy.
     */
    @Override
    public void destroy() {
        serverMap.values().forEach(ProtocolServer::close);
        channelMap.values().forEach(ReferenceCountManagedChannel::shutdown);
        serverMap.clear();
        channelMap.clear();
        super.destroy();
    }

    /**
     * Adapter exposing a gRPC {@link Server} and its fallback handler registry as a Dubbo remoting server.
     */
    public class GrpcRemotingServer extends RemotingServerAdapter {

        private final Server originalServer;
        private final DubboHandlerRegistry handlerRegistry;
        private volatile boolean started;

        public GrpcRemotingServer(Server server, DubboHandlerRegistry handlerRegistry) {
            this.originalServer = server;
            this.handlerRegistry = handlerRegistry;
        }

        /**
         * Starts the underlying gRPC server. Repeated calls after a successful start are no-ops.
         *
         * <p>The method is synchronized so that concurrent exports on the same address cannot both call
         * {@link Server#start()}; a second call to it would fail.
         *
         * @throws RpcException if the server fails to bind or start
         */
        public synchronized void start() throws RpcException {
            if (started) {
                return;
            }
            try {
                originalServer.start();
                started = true;
            } catch (IOException e) {
                throw new RpcException("Starting gRPC server failed. ", e);
            }
        }

        public DubboHandlerRegistry getRegistry() {
            return handlerRegistry;
        }

        @Override
        public Object getDelegateServer() {
            return originalServer;
        }

        public boolean isStarted() {
            return started;
        }

        /**
         * Initiates an orderly shutdown of the underlying gRPC server.
         */
        @Override
        public void close() {
            originalServer.shutdown();
        }
    }

}
